{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook uses the pre-1.0 AllenNLP API (allennlp.data.iterators,\n",
    "# Trainer(iterator=...)), so pin a release that still provides it.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install allennlp==0.9.0\n",
    "!git clone https://github.com/mhagiwara/realworldnlp.git\n",
    "%cd realworldnlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from typing import Dict, List, Tuple, Set\n",
    "\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from allennlp.common.file_utils import cached_path\n",
    "from allennlp.common.util import START_SYMBOL, END_SYMBOL\n",
    "from allennlp.data.fields import TextField\n",
    "from allennlp.data.instance import Instance\n",
    "from allennlp.data.iterators import BasicIterator\n",
    "from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer\n",
    "from allennlp.data.tokenizers import Token, CharacterTokenizer\n",
    "from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN\n",
    "from allennlp.models import Model\n",
    "from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper\n",
    "from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder\n",
    "from allennlp.modules.token_embedders import Embedding\n",
    "from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits\n",
    "from allennlp.training.trainer import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by the cells below.\n",
    "BATCH_SIZE = 128       # batch size handed to the data iterator\n",
    "EMBEDDING_SIZE = 32    # dimension of each token embedding\n",
    "HIDDEN_SIZE = 256      # size of the LSTM hidden state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_dataset(all_chars: Set[str]=None) -> List[List[Token]]:\n",
    "    \"\"\"Read a plain text file and return character-tokenized sentences.\n",
    "\n",
    "    The file is resolved through ``cached_path`` and read line by line. Blank\n",
    "    lines are skipped, runs of spaces are collapsed to a single space, and each\n",
    "    remaining line becomes one list of character tokens.\n",
    "\n",
    "    Args:\n",
    "        all_chars: Optional set of allowed characters. When it is a non-empty\n",
    "            set, tokens whose text is not in it are dropped; when it is\n",
    "            ``None`` or empty, nothing is filtered.\n",
    "\n",
    "    Returns:\n",
    "        One list of ``Token`` objects per non-blank line.\n",
    "    \"\"\"\n",
    "    tokenizer = CharacterTokenizer()\n",
    "    sentences = []\n",
    "    path = cached_path('https://s3.amazonaws.com/realworldnlpbook/data/tatoeba/sentences.eng.10k.txt')\n",
    "    # Decode explicitly as UTF-8 instead of relying on the platform's default\n",
    "    # encoding, which is not UTF-8 everywhere (assumes the corpus is UTF-8).\n",
    "    with open(path, encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if not line:\n",
    "                continue\n",
    "            line = re.sub(' +', ' ', line)\n",
    "            tokens = tokenizer.tokenize(line)\n",
    "            if all_chars:\n",
    "                tokens = [token for token in tokens if token.text in all_chars]\n",
    "            sentences.append(tokens)\n",
    "\n",
    "    return sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokens_to_lm_instance(tokens: List[Token],\n",
    "                          token_indexers: Dict[str, TokenIndexer]):\n",
    "    \"\"\"Build a language-modeling ``Instance`` from one tokenized sentence.\n",
    "\n",
    "    The sentence is wrapped in START_SYMBOL / END_SYMBOL and then split into an\n",
    "    input sequence (everything but the last token) and an output sequence\n",
    "    (everything but the first token), so the output at each position is the\n",
    "    token that follows the input at that position. The caller's list is not\n",
    "    modified.\n",
    "    \"\"\"\n",
    "    padded = [Token(START_SYMBOL), *tokens, Token(END_SYMBOL)]\n",
    "    fields = {\n",
    "        'input_tokens': TextField(padded[:-1], token_indexers),\n",
    "        'output_tokens': TextField(padded[1:], token_indexers),\n",
    "    }\n",
    "    return Instance(fields)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNLanguageModel(Model):\n",
    "    \"\"\"Recurrent language model over the 'tokens' vocabulary namespace.\n",
    "\n",
    "    Each input token is embedded, fed through an LSTM, and every LSTM output is\n",
    "    projected onto the 'tokens' vocabulary to score the next token.\n",
    "\n",
    "    Args:\n",
    "        embedder: Turns the indexed input tokens into embedding vectors.\n",
    "        hidden_size: Size of the LSTM hidden state.\n",
    "        max_len: Upper bound on the number of sampling steps in ``generate``.\n",
    "        vocab: Vocabulary whose 'tokens' namespace defines the output classes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 embedder: TextFieldEmbedder,\n",
    "                 hidden_size: int,\n",
    "                 max_len: int,\n",
    "                 vocab: Vocabulary) -> None:\n",
    "        super().__init__(vocab)\n",
    "\n",
    "        self.embedder = embedder\n",
    "\n",
    "        # Size the LSTM from the constructor arguments instead of notebook-level\n",
    "        # constants, so it always matches the embedder's output width and the\n",
    "        # initial state that generate() builds from self.hidden_size.\n",
    "        self.rnn = PytorchSeq2SeqWrapper(\n",
    "            torch.nn.LSTM(embedder.get_output_dim(), hidden_size, batch_first=True))\n",
    "\n",
    "        self.hidden2out = torch.nn.Linear(in_features=self.rnn.get_output_dim(),\n",
    "                                          out_features=vocab.get_vocab_size('tokens'))\n",
    "        self.hidden_size = hidden_size\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def forward(self, input_tokens, output_tokens):\n",
    "        \"\"\"Compute the masked sequence cross-entropy loss for a batch.\n",
    "\n",
    "        Args:\n",
    "            input_tokens: Text-field tensors for the input sequences; used to\n",
    "                build the mask and to look up embeddings.\n",
    "            output_tokens: Text-field tensors for the target sequences; the\n",
    "                'tokens' entry supplies the target indices.\n",
    "\n",
    "        Returns:\n",
    "            A dict with a single key, 'loss'.\n",
    "        \"\"\"\n",
    "        mask = get_text_field_mask(input_tokens)\n",
    "        embeddings = self.embedder(input_tokens)\n",
    "        rnn_hidden = self.rnn(embeddings, mask)\n",
    "        out_logits = self.hidden2out(rnn_hidden)\n",
    "        loss = sequence_cross_entropy_with_logits(out_logits, output_tokens['tokens'], mask)\n",
    "\n",
    "        return {'loss': loss}\n",
    "\n",
    "    def generate(self) -> Tuple[List[Token], torch.Tensor]:\n",
    "        \"\"\"Sample one sequence from the model, one token at a time.\n",
    "\n",
    "        Sampling starts from START_SYMBOL and stops when END_SYMBOL is drawn or\n",
    "        after ``max_len`` steps. START_SYMBOL and the padding token are never\n",
    "        emitted: a draw of either is discarded and repeated.\n",
    "\n",
    "        Returns:\n",
    "            The sampled tokens (without the start/end symbols) and the sum of\n",
    "            the log-probabilities of every accepted draw, including END_SYMBOL\n",
    "            when it terminates the sequence. The log-probabilities come from\n",
    "            the full softmax and are not renormalized after rejected draws.\n",
    "        \"\"\"\n",
    "        start_symbol_idx = self.vocab.get_token_index(START_SYMBOL, 'tokens')\n",
    "        end_symbol_idx = self.vocab.get_token_index(END_SYMBOL, 'tokens')\n",
    "        padding_symbol_idx = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'tokens')\n",
    "\n",
    "        # Allocate on the model's device so sampling also works once the model\n",
    "        # has been moved off the CPU.\n",
    "        device = self.hidden2out.weight.device\n",
    "\n",
    "        # Start from a tensor so the return type is the same even if no step runs.\n",
    "        log_likelihood = torch.zeros((), device=device)\n",
    "        words = []\n",
    "        state = (torch.zeros(1, 1, self.hidden_size, device=device),\n",
    "                 torch.zeros(1, 1, self.hidden_size, device=device))\n",
    "\n",
    "        word_idx = start_symbol_idx\n",
    "\n",
    "        for _ in range(self.max_len):\n",
    "            tokens = torch.tensor([[word_idx]], device=device)\n",
    "\n",
    "            embeddings = self.embedder({'tokens': tokens})\n",
    "            # Call the wrapped LSTM directly so the recurrent state can be\n",
    "            # carried from one step to the next.\n",
    "            output, state = self.rnn._module(embeddings, state)\n",
    "            output = self.hidden2out(output)\n",
    "\n",
    "            log_prob = torch.log_softmax(output[0, 0], dim=0)\n",
    "            dist = torch.exp(log_prob)\n",
    "\n",
    "            # Redraw until the sample is neither the start nor the padding symbol.\n",
    "            word_idx = start_symbol_idx\n",
    "            while word_idx in {start_symbol_idx, padding_symbol_idx}:\n",
    "                word_idx = torch.multinomial(\n",
    "                    dist, num_samples=1, replacement=False).item()\n",
    "\n",
    "            log_likelihood = log_likelihood + log_prob[word_idx]\n",
    "\n",
    "            if word_idx == end_symbol_idx:\n",
    "                break\n",
    "\n",
    "            token = Token(text=self.vocab.get_token_from_index(word_idx, 'tokens'))\n",
    "            words.append(token)\n",
    "\n",
    "        return words, log_likelihood"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Character inventory: the sentence-boundary symbols plus ASCII letters,\n",
    "# the space, and a few punctuation marks.\n",
    "all_chars = {START_SYMBOL, END_SYMBOL} | set(\n",
    "    \"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,!?'-\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the corpus, keeping only the characters listed in all_chars.\n",
    "train_set = read_dataset(all_chars=all_chars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every character in the inventory is given a count of 1 in the 'tokens' namespace.\n",
    "token_counts = dict.fromkeys(all_chars, 1)\n",
    "vocab = Vocabulary({'tokens': token_counts})\n",
    "\n",
    "# One language-modeling instance per training sentence.\n",
    "token_indexers = {'tokens': SingleIdTokenIndexer()}\n",
    "instances = [\n",
    "    tokens_to_lm_instance(sentence, token_indexers)\n",
    "    for sentence in train_set\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding table with one row per entry of the 'tokens' vocabulary namespace.\n",
    "token_embedding = Embedding(\n",
    "    embedding_dim=EMBEDDING_SIZE,\n",
    "    num_embeddings=vocab.get_vocab_size('tokens'))\n",
    "embedder = BasicTextFieldEmbedder({\"tokens\": token_embedding})\n",
    "\n",
    "# max_len bounds the number of sampling steps taken by model.generate().\n",
    "model = RNNLanguageModel(\n",
    "    vocab=vocab,\n",
    "    embedder=embedder,\n",
    "    hidden_size=HIDDEN_SIZE,\n",
    "    max_len=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all model parameters.\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.005)\n",
    "\n",
    "# Batching iterator, indexed with the vocabulary built earlier.\n",
    "iterator = BasicIterator(batch_size=BATCH_SIZE)\n",
    "iterator.index_with(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 10 epochs; train() stays the last expression so its return value\n",
    "# is displayed as the cell output.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    iterator=iterator,\n",
    "    train_dataset=instances,\n",
    "    num_epochs=10,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(text: str, model: Model) -> float:\n",
    "    \"\"\"Score ``text`` with the language model and return its loss.\n",
    "\n",
    "    The text is split into character tokens, converted to a language-modeling\n",
    "    instance, and passed through ``model.forward_on_instance``. The raw output\n",
    "    dict is printed, and its 'loss' entry is returned.\n",
    "\n",
    "    Args:\n",
    "        text: Sentence to score.\n",
    "        model: Model whose forward output contains a 'loss' entry.\n",
    "\n",
    "    Returns:\n",
    "        The loss for ``text`` as a Python float.\n",
    "    \"\"\"\n",
    "    tokenizer = CharacterTokenizer()\n",
    "    tokens = tokenizer.tokenize(text)\n",
    "\n",
    "    token_indexers = {'tokens': SingleIdTokenIndexer()}\n",
    "    instance = tokens_to_lm_instance(tokens, token_indexers)\n",
    "    output = model.forward_on_instance(instance)\n",
    "    print(output)\n",
    "    # The signature promises a float, so hand the loss back to the caller\n",
    "    # instead of implicitly returning None.\n",
    "    return float(output['loss'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A fluent sentence.\n",
    "predict(text='The trip to the beach was ruined by bad weather.', model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same sentence with the last word swapped.\n",
    "predict(text='The trip to the beach was ruined by bad dogs.', model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A scrambled word order for comparison.\n",
    "predict(text='by weather was trip my bad beach the ruined to.', model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print 50 sampled sequences; the log-likelihood returned by generate()\n",
    "# is not needed here.\n",
    "for _ in range(50):\n",
    "    sampled_tokens = model.generate()[0]\n",
    "    print(''.join(t.text for t in sampled_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
